using ArgParse, JLD2, Printf, JSON, Dates, IterTools, Random;
using Distributed;

@everywhere include("runit.jl");
@everywhere include("helpers.jl");
@everywhere include("../binary_search.jl");
include("helpers_experiments.jl");

# Build the command-line interface of this experiment script and parse ARGS.
#
# Returns the Dict produced by ArgParse's `parse_args`, keyed by option name without
# the leading dashes ("save_dir", "data_dir", "seed", "inst", "K", "mu1", "gapmin",
# "gapmax", "B", "expe", "Nruns"). Options absent from the command line take the
# defaults listed in `options` below.
function parse_commandline()
    # (option, help text, value type, default), in the order they are registered.
    options = [
        ("--save_dir", "Directory for saving the experiment's data.", String, "experiments/"),
        ("--data_dir", "Directory for loading the data.", String, "data/"),
        ("--seed", "Seed.", Int64, 42),
        ("--inst", "Instance considered.", String, "distinct"),
        ("--K", "Number of arms.", Int64, 5),
        ("--mu1", "Best arm.", Float64, 0.6),
        ("--gapmin", "Min gap.", Float64, 0.1),
        ("--gapmax", "Max gap.", Float64, 0.4),
        ("--B", "Upper bound.", Float64, 1.0),
        ("--expe", "Experiment considered.", String, "random"),
        ("--Nruns", "Number of runs of the experiment.", Int64, 8),
    ];

    settings = ArgParseSettings();
    for (name, help_text, T, default_value) in options
        add_arg_table!(settings, name,
                       Dict{Symbol,Any}(:help => help_text,
                                        :arg_type => T,
                                        :default => default_value));
    end

    return parse_args(settings);
end

# Draw a random bandit instance with `param_inst["nK"]` Bernoulli-type arms.
#
# Arm 1 has mean `param_inst["mu1"]`. Every other arm a has mean `mu1 - gap_a`, where
# gap_a = (gapmax - gapmin) * rand(rng) + gapmin, i.e. uniform on [gapmin, gapmax).
#   - "uniform":  gaps are drawn independently.
#   - "distinct": a gap is redrawn until the resulting mean is more than 0.001 away
#                 from every mean already assigned (arms 1..a-1).
# Any other `param_inst["inst"]` throws an ArgumentError.
#
# Arguments: `param_inst` is a Dict with keys "nK", "mu1", "inst", "gapmin", "gapmax"
# and "B"; `rng` is the random number generator used for every draw.
# Returns `(μs, dists)`: the vector of arm means and one `BernoulliGen(μ, B)` per arm.
@everywhere function get_rand_instance(param_inst, rng)
    nK = param_inst["nK"];
    mu1 = param_inst["mu1"];
    gapmin = param_inst["gapmin"];
    gapmax = param_inst["gapmax"];
    inst = param_inst["inst"];

    # Fail loudly: previously an unknown type only logged and returned all arms at mu1.
    if !(inst in ("uniform", "distinct"))
        throw(ArgumentError("Not implemented: unknown instance type \"$(inst)\""));
    end

    draw_gap() = (gapmax - gapmin) * rand(rng) + gapmin;

    μs = mu1 * ones(nK);
    for a in 2:nK
        gap = draw_gap();
        if inst == "distinct"
            # Compare the candidate *mean* (mu1 - gap) with the means already drawn.
            # The old test compared the gap itself against the means, which did not
            # enforce distinct arms.
            while minimum(abs.(μs[1:a-1] .- (mu1 - gap))) <= 0.001
                gap = draw_gap();
            end
        end
        μs[a] -= gap;
    end

    dists = [BernoulliGen(μ, param_inst["B"]) for μ in μs];
    return μs, dists;
end

# Evaluate every identification strategy in `iss` on one random instance.
#
# The instance is drawn with a MersenneTwister seeded by `seed`, wrapped in a
# `BestArm` problem, and each strategy is passed to `runit` together with `seed`,
# `δs` and `Tau_max`. Returns a flat vector of `(strategy, result)` pairs: one
# pair per element returned by `runit`, in strategy order. Each result is
# converted to Tuple{Int64, Vector{Int64}, UInt64} by the container's element type.
@everywhere function run_inst(seed, iss, δs, param_inst, Tau_max)
    rng = MersenneTwister(seed);

    # Only the arm distributions are needed to build the pure-exploration problem.
    _, dists = get_rand_instance(param_inst, rng);
    pep = BestArm(dists, param_inst["B"]);

    outcomes = Tuple{Any, Tuple{Int64, Vector{Int64}, UInt64}}[];
    for strategy in iss
        append!(outcomes,
                ((strategy, r) for r in runit(seed, strategy, pep, δs, Tau_max)));
    end

    return outcomes;
end

# Command-line parameters, unpacked into script-level globals.
parsed_args = parse_commandline();
save_dir, data_dir, seed, inst, expe, Nruns =
    (parsed_args[key]
     for key in ("save_dir", "data_dir", "seed", "inst", "expe", "Nruns"));
nK, mu1, gapmin, gapmax, B =
    (parsed_args[key] for key in ("K", "mu1", "gapmin", "gapmax", "B"));

# Cap handed through run_inst to runit (presumably the maximal stopping time of a run).
Tau_max = 1e6;

# Everything get_rand_instance needs to draw an instance (mu1 is the best arm's mean).
param_inst = Dict(
    "inst" => inst, "nK" => nK, "B" => B, "mu1" => mu1,
    "gapmin" => gapmin, "gapmax" => gapmax,
);

# δ values given to runit; `data` reserves one column per (run, δ) pair.
δs = [0.01];

# Name the experiment and create its output folder "<save_dir>/<dd-mm_HHhMM>:<name>/".
# NOTE(review): ':' is not a valid path character on Windows; it is kept so folder
# names stay compatible with previously produced results.
now_str = Dates.format(now(), "dd-mm_HHhMM");
experiment_name = "exp_random_" * expe * "_" * inst * "_K" * string(nK) * "_N" * string(Nruns);
# joinpath adds a separator even when save_dir has no trailing '/'. The final "" keeps
# the trailing separator that the "$(experiment_dir)file" paths below rely on.
experiment_dir = joinpath(save_dir, now_str * ":" * experiment_name, "");
# Create save_dir (and its parents) if needed. The experiment folder itself still uses
# mkdir, so an existing folder with the same name errors instead of being overwritten.
isempty(save_dir) || mkpath(save_dir);
mkdir(experiment_dir);
open("$(experiment_dir)parsed_args.json", "w") do f
    JSON.print(f, parsed_args)
end

# Identification strategies evaluated on every instance (tuples (sr, rsp)), built by
# `everybody` for experiment `expe` with the uniform vector ones(nK) / nK.
iss = everybody(expe, ones(nK) / nK);
# Row of each strategy in `data`, and the next free column of that row.
iss_index = Dict(strategy => row for (row, strategy) in enumerate(iss));
iss_counters = Dict(strategy => 1 for strategy in iss);

# One task per run (run k uses seed + k), distributed over the available workers.
@time _data = pmap(k -> run_inst(seed + k, iss, δs, param_inst, Tau_max), 1:Nruns);

# Gather the per-run results into a matrix with one row per strategy (row iss_index[is])
# and Nruns * length(δs) columns. Columns are filled in the order pmap returned the runs,
# using iss_counters[is] as the next free column of each row.
data = Array{Tuple{Int64, Array{Int64,1}, UInt64}}(undef, length(iss), Nruns * length(δs));
for chunk in _data
    for (is, res) in chunk
        data[iss_index[is], iss_counters[is]] = res;
        iss_counters[is] += 1;
    end
end
# Every cell must be filled. An #undef entry would otherwise be saved silently and only
# fail later, when the data is read back or summarised.
for (is, next_col) in iss_counters
    next_col == size(data, 2) + 1 ||
        error("Strategy $(is) produced $(next_col - 1) results, expected $(size(data, 2)).");
end

# Persist the raw results together with what is needed to interpret them
# (JLD2 file "<experiment_name>.dat" inside experiment_dir).
@save "$(experiment_dir)$(experiment_name).dat" iss data iss_index δs param_inst Nruns seed;

# Write a human-readable summary of the considered problem next to the data file.
file = string(experiment_dir, "summary_", experiment_name, ".txt");
print_rand_summary(δs, iss, data, iss_index, param_inst, Nruns, file);
